{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#  &#x1F4D1; **作业 2: 音位分类 (分类)**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "学习目标：\n",
    "* 数据预处理：从原始波形中提取MFCC特征\n",
    "* 分类：使用预提取的MFCC特征执行逐帧音位分类\n",
    "* 熟悉并提高pytorch训练技巧，熟悉pytorch模块\n",
    "相关资料：\n",
    "* Slides地址: https://docs.google.com/presentation/d/1v6HkBWiJb8WNDcJ9_-2kwVstxUWml87b9CnA16Gdoio/edit?usp=sharing\n",
    "* Kaggle地址: https://www.kaggle.com/c/ml2022spring-hw2\n",
    "* 相关课程视频资源也可在B站获取  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "前置知识：\n",
    "- 音位：\n",
    "\n",
    "(phonetics 语音) 音位，音素（区分单词的最小语音单位，英语sip中的s和zip中的z是两个不同的音素）  \n",
    "\n",
    "例如：Machine Learning → M AH SH IH N L ER N IH NG M M M AH AH SH SH IH IH IH N N N N ... Machine\n",
    "\n",
    "- MFCC：  \n",
    "\n",
    "在语音识别（Speech recognition）和话者识别（Speaker recognition）方面，最常用到的语音特征就是梅尔倒谱系数（Mel-scale Frequency Cepstral Coefficients, MFCC）。  \n",
    "\n",
    "根据人耳听觉机理的研究发现，人耳对不同频率的声波有不同的听觉敏感度。从200Hz到5000Hz的语音信号对语音的清晰度影响对大。两个响度不等的声音作用于人耳时，则响度较高的频率成分的存在会影响到对响度较低的频率成分的感受，使其变得不易察觉，这种现象称为掩蔽效应。由于频率较低的声音在内耳蜗基底膜上行波传递的距离（速度）大于频率较高的声音，故一般来说，低音容易掩蔽高音，而高音掩蔽低音较困难。在低频处的声音掩蔽的临界带宽较高频要小。所以，人们从低频到高频这一段频带内按临界带宽的大小由密到疏安排一组带通滤波器，对输入信号进行滤波。将每个带通滤波器输出的信号能量作为信号的基本特征，对此特征经过进一步处理后就可以作为语音的输入特征。由于这种特征不依赖于信号的性质，对输入信号不做任何的假设和限制，又利用了听觉模型的研究成果。因此，这种参数比基于声道模型的LPCC相比具有更好的鲁棒性，更符合人耳的听觉特性，而且当信噪比降低时仍然具有较好的识别性能。  \n",
    "\n",
    "在本次作业中，正常的一段音频素材可能包含大量的音位信息，而音位之间又可能存在重叠干扰的情况，因此我们将一段音频素材每隔10ms切取25ms，以此来尽可能保存完整的音位素材，取出后的素材称为一个frame，取出后的frame并不适合直接进入训练，因此我们要进行进一步的处理，通过MFCC，将它转化为一个39维度的特征，转换后为了更加精确的判断当前特征内的音位信息，我们往往采取其前后的特征来做辅助判断 也就是前向特征与后向特征各取5个，所以我们最后得到的是一个11*39维的一个向量。\n",
    "\n",
    "![](./pic/01.png) \n",
    "想要深入了解实现过程的可以查看下列链接：  \n",
    "[Prof. Hung-Yi Lee[2020Spring DLHLP] Speech Recognition](https://speech.ee.ntu.edu.tw/~tlkagk/courses/DLHLP20/ASR%20(v12).pdf)  \n",
    "[ Prof. Lin-Shan Lee’s[Introduction to Digital Speech Processing]Chap.7](http://ocw.aca.ntu.edu.tw/ntu-ocw/ocw/cou/104S204)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#  数据集下载\n",
    "如果下列命令无法下载，可以到下列地址下载数据\n",
    "- Kaggle下载数据:  [Kaggle: ml2022spring-hw2](https://www.kaggle.com/competitions/ml2022spring-hw2)\n",
    "- 百度云下载数据: [云盘(提取码：05zc)](https://pan.baidu.com/s/198xn8Lk9MjvUsq866mZuuw)\n",
    "\n",
    "\n",
    "下载完成后，你应该能够获取如下文件：\n",
    "- `libriphone/train_split.txt`\n",
    "- `libriphone/train_labels`\n",
    "- `libriphone/test_split.txt`\n",
    "- `libriphone/feat/train/*.pt`: training feature<br>\n",
    "- `libriphone/feat/test/*.pt`:  testing feature<br>  \n",
    "\n",
    "pt文件可以使用torch.load方法导入\n",
    "\n",
    "\n",
    "<b>同学们下载完后直接解压到 HW02文件夹下面（将里面的文件最终放到HW02下）</b>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 下载链接\n",
    "!wget -O libriphone.zip \"https://github.com/xraychen/shiny-robot/releases/download/v1.0/libriphone.zip\"\n",
    "\n",
    "# 下列数据获取方式需要依靠gdown\n",
    "\n",
    "# 备用链接 0\n",
    "# !pip install --upgrade gdown\n",
    "# !gdown --id '1o6Ag-G3qItSmYhTheX6DYiuyNzWyHyTc' --output libriphone.zip\n",
    "\n",
    "# 备用链接 1\n",
    "# !pip install --upgrade gdown\n",
    "# !gdown --id '1R1uQYi4QpX0tBfUWt2mbZcncdBsJkxeW' --output libriphone.zip\n",
    "\n",
    "# 备用链接 2\n",
    "# !wget -O libriphone.zip \"https://www.dropbox.com/s/wqww8c5dbrl2ka9/libriphone.zip?dl=1\"\n",
    "\n",
    "# 备用链接 3\n",
    "# !wget -O libriphone.zip \"https://www.dropbox.com/s/p2ljbtb2bam13in/libriphone.zip?dl=1\"\n",
    "\n",
    "!unzip -q libriphone.zip\n",
    "!ls libriphone"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sat Dec  3 21:46:04 2022       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 526.98       Driver Version: 526.98       CUDA Version: 12.0     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                               |                      |               MIG M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0  On |                  N/A |\n",
      "| N/A   48C    P5     8W /  N/A |   2444MiB /  4096MiB |     25%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "|    0   N/A  N/A      3008    C+G   ...2gh52qy24etm\\Nahimic3.exe    N/A      |\n",
      "|    0   N/A  N/A      3960    C+G   ...ge\\Application\\msedge.exe    N/A      |\n",
      "|    0   N/A  N/A      8136    C+G   C:\\Windows\\explorer.exe         N/A      |\n",
      "|    0   N/A  N/A     10176    C+G   ...n1h2txyewy\\SearchHost.exe    N/A      |\n",
      "|    0   N/A  N/A     10224    C+G   ...artMenuExperienceHost.exe    N/A      |\n",
      "|    0   N/A  N/A     10992    C+G   ...y\\ShellExperienceHost.exe    N/A      |\n",
      "|    0   N/A  N/A     11364    C+G   ...perience\\NVIDIA Share.exe    N/A      |\n",
      "|    0   N/A  N/A     12308    C+G   ...in7x64\\steamwebhelper.exe    N/A      |\n",
      "|    0   N/A  N/A     12728    C+G   ...2txyewy\\TextInputHost.exe    N/A      |\n",
      "|    0   N/A  N/A     12868    C+G   ...cw5n1h2txyewy\\LockApp.exe    N/A      |\n",
      "|    0   N/A  N/A     13060    C+G   ...perience\\NVIDIA Share.exe    N/A      |\n",
      "|    0   N/A  N/A     16204    C+G   ...er_engine\\wallpaper32.exe    N/A      |\n",
      "|    0   N/A  N/A     16548    C+G   ...8wekyb3d8bbwe\\Cortana.exe    N/A      |\n",
      "|    0   N/A  N/A     16680    C+G   ...oft OneDrive\\OneDrive.exe    N/A      |\n",
      "|    0   N/A  N/A     18320    C+G   ...d\\runtime\\WeChatAppEx.exe    N/A      |\n",
      "|    0   N/A  N/A     20400    C+G   ...tracted\\WechatBrowser.exe    N/A      |\n",
      "|    0   N/A  N/A     21812    C+G   ...ekyb3d8bbwe\\HxOutlook.exe    N/A      |\n",
      "|    0   N/A  N/A     22332    C+G   ...lPanel\\SystemSettings.exe    N/A      |\n",
      "|    0   N/A  N/A     22568    C+G   ...icrosoft VS Code\\Code.exe    N/A      |\n",
      "|    0   N/A  N/A     30848    C+G   D:\\CloudMusic\\cloudmusic.exe    N/A      |\n",
      "|    0   N/A  N/A     31160    C+G   ...418.62\\msedgewebview2.exe    N/A      |\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "# 输入如下指令查看GPU状态\n",
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备数据"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Helper函数用于预处理来自每个话语的原始MFCC特征的训练数据**\n",
    "\n",
    "\n",
    "一个音位可能跨越几个帧，并且取决于过去和将来的帧\n",
    "\n",
    "因此，我们连接相邻的音位进行训练以获得更高的准确性。**concat_fatte**函数连接过去和未来的k帧（总共2k+1＝n帧），我们预测中心帧。\n",
    "\n",
    "\n",
    "可以随意修改数据预处理函数，但**不要删除任何帧**（如果您修改函数，请记住检查帧的数量是否与幻灯片中提到的相同）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import pandas as pd\n",
    "#导入pytorch\n",
    "import torch\n",
    "#导入进度条\n",
    "from tqdm import tqdm\n",
    "# 定义导入feature函数\n",
    "def load_feat(path):\n",
    "    feat = torch.load(path)\n",
    "    return feat\n",
    "\n",
    "def shift(x, n):\n",
    "    if n < 0:\n",
    "        left = x[0].repeat(-n, 1)\n",
    "        right = x[:n]\n",
    "\n",
    "    elif n > 0:\n",
    "        right = x[-1].repeat(n, 1)\n",
    "        left = x[n:]\n",
    "    else:\n",
    "        return x\n",
    "\n",
    "    return torch.cat((left, right), dim=0)\n",
    "# 将前后的特征联系在一起，如concat_n = 11 则前后都接5\n",
    "def concat_feat(x, concat_n):\n",
    "    assert concat_n % 2 == 1 # n 必须是奇数\n",
    "    if concat_n < 2:\n",
    "        return x\n",
    "    seq_len, feature_dim = x.size(0), x.size(1)\n",
    "    x = x.repeat(1, concat_n) \n",
    "    x = x.view(seq_len, concat_n, feature_dim).permute(1, 0, 2) # concat_n, seq_len, feature_dim\n",
    "    mid = (concat_n // 2)\n",
    "    for r_idx in range(1, mid+1):\n",
    "        x[mid + r_idx, :] = shift(x[mid + r_idx], r_idx)\n",
    "        x[mid - r_idx, :] = shift(x[mid - r_idx], -r_idx)\n",
    "\n",
    "    return x.permute(1, 0, 2).view(seq_len, concat_n * feature_dim)\n",
    "\n",
    "# 数据预处理函数\n",
    "def preprocess_data(split, feat_dir, phone_path, concat_nframes, train_ratio=0.8, train_val_seed=1337):\n",
    "    class_num = 41 # NOTE: 预先计算，不需要更改\n",
    "    mode = 'train' if (split == 'train' or split == 'val') else 'test'\n",
    "\n",
    "    label_dict = {}\n",
    "    if mode != 'test':\n",
    "      phone_file = open(os.path.join(phone_path, f'{mode}_labels.txt')).readlines()\n",
    "\n",
    "      for line in phone_file:\n",
    "          line = line.strip('\\n').split(' ')\n",
    "          label_dict[line[0]] = [int(p) for p in line[1:]]\n",
    "\n",
    "    if split == 'train' or split == 'val':\n",
    "        # 分割训练和验证数据\n",
    "        usage_list = open(os.path.join(phone_path, 'train_split.txt')).readlines()\n",
    "        random.seed(train_val_seed)\n",
    "        random.shuffle(usage_list)\n",
    "        percent = int(len(usage_list) * train_ratio)\n",
    "        usage_list = usage_list[:percent] if split == 'train' else usage_list[percent:]\n",
    "    elif split == 'test':\n",
    "        usage_list = open(os.path.join(phone_path, 'test_split.txt')).readlines()\n",
    "    else:\n",
    "        raise ValueError('Invalid \\'split\\' argument for dataset: PhoneDataset!')\n",
    "\n",
    "    usage_list = [line.strip('\\n') for line in usage_list]\n",
    "    print('[Dataset] - # phone classes: ' + str(class_num) + ', number of utterances for ' + split + ': ' + str(len(usage_list)))\n",
    "\n",
    "    max_len = 3000000\n",
    "    X = torch.empty(max_len, 39 * concat_nframes)\n",
    "    if mode != 'test':\n",
    "      y = torch.empty(max_len, dtype=torch.long)\n",
    "\n",
    "    idx = 0\n",
    "    for i, fname in tqdm(enumerate(usage_list)):\n",
    "        feat = load_feat(os.path.join(feat_dir, mode, f'{fname}.pt'))\n",
    "        cur_len = len(feat)\n",
    "        feat = concat_feat(feat, concat_nframes)\n",
    "        if mode != 'test':\n",
    "          label = torch.LongTensor(label_dict[fname])\n",
    "\n",
    "        X[idx: idx + cur_len, :] = feat\n",
    "        if mode != 'test':\n",
    "          y[idx: idx + cur_len] = label\n",
    "\n",
    "        idx += cur_len\n",
    "\n",
    "    X = X[:idx, :]\n",
    "    if mode != 'test':\n",
    "      y = y[:idx]\n",
    "\n",
    "    print(f'[INFO] {split} set')\n",
    "    print(X.shape)\n",
    "    if mode != 'test':\n",
    "      print(y.shape)\n",
    "      return X, y\n",
    "    else:\n",
    "      return X\n",
    "# 返回的X代表数据的维度，如果不链接则为39 如果链接即为n*39 n为连接的特征总数,y为标签"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "#导入数据集\n",
    "from torch.utils.data import Dataset\n",
    "#导入数据加载工具Dataloader\n",
    "from torch.utils.data import DataLoader\n",
    "#定义数据集，一个数据集类应该包含初始化，_getitem__（获取一个元素）以及__len__（获取数据长度）方法\n",
    "class LibriDataset(Dataset):\n",
    "    def __init__(self, X, y=None):\n",
    "        self.data = X\n",
    "        if y is not None:\n",
    "            self.label = torch.LongTensor(y)\n",
    "        else:\n",
    "            self.label = None\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        if self.label is not None:\n",
    "            return self.data[idx], self.label[idx]\n",
    "        else:\n",
    "            return self.data[idx]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#  &#x2728; 神经网络模型\n",
    "<font color=darkred><b>***TODO***: 使用近似相同数量的参数实现2个模型，（A）一个更窄和更深（例如hidden_layers=6，hidden_dim＝1024）和（B）另一个更宽和更浅（例如 hidden_layers＝2、hidden_dim＝1700）。报告两种模型的训练/验证精度。</font></b>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<font color=darkred><b>***TODO***:  添加dropout层，并报告dropout率分别等于（A）0.25/（B）0.5/（C）0.75的训练/验证准确性。</font></b>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Dropout层在神经网络层当中是用来干什么的呢？它是一种可以用于减少神经网络过拟合的结构。\n",
    "![](./pic/02.png)   \n",
    "如上图我们定义的网络,一共有四个输入x_i，一个输出y。Dropout则是在每一个batch的训练当中随机减掉一些神经元，而作为编程者，我们可以设定每一层dropout（将神经元去除的的多少）的概率，在设定之后，就可以得到第一个batch进行训练的结果：  \n",
    "![](./pic/03.png)   \n",
    "从上图我们可以看到一些神经元之间断开了连接，因此它们被dropout了！dropout顾名思义就是被拿掉的意思，正因为我们在神经网络当中拿掉了一些神经元，所以才叫做dropout层。\n",
    "在进行第一个batch的训练时，有以下步骤：\n",
    "* 设定每一个神经网络层进行dropout的概率\n",
    "* 根据相应的概率拿掉一部分的神经元，然后开始训练，更新没有被拿掉神经元以及权重的参数，将其保留\n",
    "* 参数全部更新之后，又重新根据相应的概率拿掉一部分神经元，然后开始训练，如果新用于训练的神经元已经在第一次当中训练过，那么我们继续更新它的参数。而第二次被剪掉的神经元，同时第一次已经更新过参数的，我们保留它的权重，不做修改，直到第n次batch进行dropout时没有将其删除。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PS: 上面的两个TODO是可以改进其他部分来提高你的成绩的方法。  \n",
    "如下的策略是助教给出的几个优化方式：  \n",
    "● (1%) Simple baseline: 0.45797 (sample code)  \n",
    "● (1%) Medium baseline: 0.69747 (concat n frames, add layers)  \n",
    "● (1%) Strong baseline: 0.75028 (concat n, batchnorm, dropout, add layers)  \n",
    "● (1%) Boss baseline: 0.82324 (sequence-labeling(using RNN))  \n",
    "对于boss baseline，您可以参考RNN之前的课程记录"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "# 建立神经网络\n",
    "class BasicBlock(nn.Module):# 继承 torch 的 Module\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        super(BasicBlock, self).__init__()\n",
    "\n",
    "        self.block = nn.Sequential(\n",
    "            nn.Linear(input_dim, output_dim),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.block(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class Classifier(nn.Module):\n",
    "    def __init__(self, input_dim, output_dim=41, hidden_layers=1, hidden_dim=256):\n",
    "        super(Classifier, self).__init__()\n",
    "\n",
    "        self.fc = nn.Sequential(\n",
    "            BasicBlock(input_dim, hidden_dim),\n",
    "            *[BasicBlock(hidden_dim, hidden_dim) for _ in range(hidden_layers)], # *[]将循环得到的解压\n",
    "            nn.Linear(hidden_dim, output_dim)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 超参数定义"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<font color=darkred><b>***TODO***:  可以考虑进一步优化超参数来提高准确率。</font></b>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data prarameters\n",
    "# 用于数据处理时的参数\n",
    "concat_nframes = 1              # 要连接的帧数,n必须为奇数（总共2k+1=n帧）\n",
    "train_ratio = 0.8               # 用于训练的数据比率，其余数据将用于验证\n",
    "# training parameters\n",
    "# 训练过程中的参数\n",
    "seed = 0                        # 随机种子\n",
    "batch_size = 512                # 批次数目\n",
    "num_epoch = 5                   # 训练epoch数\n",
    "learning_rate = 0.0001          # 学习率\n",
    "model_path = './model.ckpt'     # 选择保存检查点的路径（即下文调用保存模型函数的保存位置）\n",
    "# model parameters\n",
    "# 模型参数\n",
    "input_dim = 39 * concat_nframes # 模型的输入维度，不应更改该值，这个值由上面的拼接函数决定\n",
    "hidden_layers = 1               # hidden_layer层的数量\n",
    "hidden_dim = 256                # 隐藏维度"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备数据与模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Dataset] - # phone classes: 41, number of utterances for train: 3428\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3428it [00:01, 1903.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] train set\n",
      "torch.Size([2116368, 39])\n",
      "torch.Size([2116368])\n",
      "[Dataset] - # phone classes: 41, number of utterances for val: 858\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "858it [00:00, 1994.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] val set\n",
      "torch.Size([527790, 39])\n",
      "torch.Size([527790])\n"
     ]
    }
   ],
   "source": [
    "# 引入gc模块进行垃圾回收\n",
    "import gc\n",
    "\n",
    "# 预处理数据\n",
    "train_X, train_y = preprocess_data(split='train', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes, train_ratio=train_ratio)\n",
    "val_X, val_y = preprocess_data(split='val', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes, train_ratio=train_ratio)\n",
    "\n",
    "# 将数据导入\n",
    "train_set = LibriDataset(train_X, train_y)\n",
    "val_set = LibriDataset(val_X, val_y)\n",
    "\n",
    "# 删除原始数据以节省内存\n",
    "del train_X, train_y, val_X, val_y\n",
    "gc.collect()\n",
    "\n",
    "# 利用dataloader加载数据\n",
    "train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DEVICE: cuda:0\n"
     ]
    }
   ],
   "source": [
    "# 检查当前是否有可用的GPU 否则使用CPU\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "print(f'DEVICE: {device}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# 固定随机种子\n",
    "def same_seeds(seed): # 固定随机种子（CPU）\n",
    "    torch.manual_seed(seed) # 固定随机种子（GPU)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed(seed) # 为当前GPU设置\n",
    "        torch.cuda.manual_seed_all(seed)  # 为所有GPU设置\n",
    "    np.random.seed(seed)  # 保证后续使用random函数时，产生固定的随机数\n",
    "    torch.backends.cudnn.benchmark = False # GPU、网络结构固定，可设置为True\n",
    "    torch.backends.cudnn.deterministic = True # 固定网络结构"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 固定随机种子\n",
    "same_seeds(seed)\n",
    "\n",
    "# 创建模型、定义损失函数和优化器\n",
    "model = Classifier(input_dim=input_dim, hidden_layers=hidden_layers, hidden_dim=hidden_dim).to(device)\n",
    "criterion = nn.CrossEntropyLoss() \n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4134/4134 [00:30<00:00, 133.73it/s]\n",
      "100%|██████████| 1031/1031 [00:03<00:00, 295.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[001/005] Train Acc: 0.421898 Loss: 2.086656 | Val Acc: 0.440558 loss: 1.971946\n",
      "saving model with acc 0.441\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4134/4134 [00:27<00:00, 148.92it/s]\n",
      "100%|██████████| 1031/1031 [00:03<00:00, 290.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[002/005] Train Acc: 0.449010 Loss: 1.934501 | Val Acc: 0.449703 loss: 1.926872\n",
      "saving model with acc 0.450\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4134/4134 [00:27<00:00, 150.28it/s]\n",
      "100%|██████████| 1031/1031 [00:03<00:00, 292.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[003/005] Train Acc: 0.455098 Loss: 1.904368 | Val Acc: 0.453834 loss: 1.908885\n",
      "saving model with acc 0.454\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4134/4134 [00:27<00:00, 150.31it/s]\n",
      "100%|██████████| 1031/1031 [00:03<00:00, 284.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[004/005] Train Acc: 0.458366 Loss: 1.888013 | Val Acc: 0.456125 loss: 1.896527\n",
      "saving model with acc 0.456\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4134/4134 [00:27<00:00, 150.50it/s]\n",
      "100%|██████████| 1031/1031 [00:03<00:00, 288.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[005/005] Train Acc: 0.460736 Loss: 1.876455 | Val Acc: 0.457858 loss: 1.889775\n",
      "saving model with acc 0.458\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "best_acc = 0.0\n",
    "for epoch in range(num_epoch):\n",
    "    train_acc = 0.0\n",
    "    train_loss = 0.0\n",
    "    val_acc = 0.0\n",
    "    val_loss = 0.0\n",
    "    \n",
    "    # 训练部分\n",
    "    model.train() # 设定模型到训练模式\n",
    "    for i, batch in enumerate(tqdm(train_loader)):\n",
    "        features, labels = batch\n",
    "        features = features.to(device)\n",
    "        labels = labels.to(device)\n",
    "        \n",
    "        optimizer.zero_grad() \n",
    "        outputs = model(features) \n",
    "        \n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward() \n",
    "        optimizer.step() \n",
    "        \n",
    "        _, train_pred = torch.max(outputs, 1) # 获得概率最高的类的索引\n",
    "        train_acc += (train_pred.detach() == labels.detach()).sum().item()\n",
    "        train_loss += loss.item()\n",
    "    \n",
    "    # 验证部分\n",
    "    if len(val_set) > 0:\n",
    "        model.eval() # 设定模型到评估模式\n",
    "        with torch.no_grad():\n",
    "            for i, batch in enumerate(tqdm(val_loader)):\n",
    "                features, labels = batch\n",
    "                features = features.to(device)\n",
    "                labels = labels.to(device)\n",
    "                outputs = model(features)\n",
    "                \n",
    "                loss = criterion(outputs, labels) \n",
    "                \n",
    "                _, val_pred = torch.max(outputs, 1) \n",
    "                val_acc += (val_pred.cpu() == labels.cpu()).sum().item() # 获得概率最高的类的索引\n",
    "                val_loss += loss.item()\n",
    "\n",
    "            print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f} | Val Acc: {:3.6f} loss: {:3.6f}'.format(\n",
    "                epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader), val_acc/len(val_set), val_loss/len(val_loader)\n",
    "            ))\n",
    "\n",
    "            # 如果模型获得提升，在此阶段保存模型\n",
    "            if val_acc > best_acc:\n",
    "                best_acc = val_acc\n",
    "                torch.save(model.state_dict(), model_path)\n",
    "                print('saving model with acc {:.3f}'.format(best_acc/len(val_set)))\n",
    "    else:\n",
    "        print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f}'.format(\n",
    "            epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader)\n",
    "        ))\n",
    "\n",
    "# 如果结束验证，则保存最后一个epoch得到的模型\n",
    "if len(val_set) == 0:\n",
    "    torch.save(model.state_dict(), model_path)\n",
    "    print('saving model at last epoch')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "del train_loader, val_loader\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试\n",
    "创建测试数据集，并从保存的检查点加载模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Dataset] - # phone classes: 41, number of utterances for test: 1078\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1078it [00:00, 1188.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] test set\n",
      "torch.Size([646268, 39])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# 载入数据\n",
    "test_X = preprocess_data(split='test', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes)\n",
    "test_set = LibriDataset(test_X, None)\n",
    "test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 加载已经训练好的模型\n",
    "model = Classifier(input_dim=input_dim, hidden_layers=hidden_layers, hidden_dim=hidden_dim).to(device)\n",
    "model.load_state_dict(torch.load(model_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1263/1263 [00:02<00:00, 423.45it/s]\n"
     ]
    }
   ],
   "source": [
    "test_acc = 0.0\n",
    "test_lengths = 0\n",
    "pred = np.array([], dtype=np.int32)\n",
    "\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for i, batch in enumerate(tqdm(test_loader)):\n",
    "        features = batch\n",
    "        features = features.to(device)\n",
    "\n",
    "        outputs = model(features)\n",
    "\n",
    "        _, test_pred = torch.max(outputs, 1) # 获得概率最高的类的索引\n",
    "        pred = np.concatenate((pred, test_pred.cpu().numpy()), axis=0)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "将预测结果写入CSV文件。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('prediction.csv', 'w') as f:\n",
    "    f.write('Id,Class\\n')\n",
    "    for i, y in enumerate(pred):\n",
    "        f.write('{},{}\\n'.format(i, y))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 参考文献：  \n",
    "[一文入门dropout层](https://www.cnblogs.com/geeksongs/p/13446980.html)  \n",
    "李宏毅机器学习2022在线课程"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 贡献者  \n",
    "潘笃驿(panduyi_azula@foxmail.com)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "vscode": {
   "interpreter": {
    "hash": "5da60e2a9548b1451378568c9bbd9389177ea512d36b61a0763bf3d806b69d96"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
